#!/usr/bin/env python3

import collections
import csv
import os
import itertools
from pprint import pprint
import shutil
import sys

import numpy as np
import pandas as pd

QUACLIB = os.path.abspath(os.path.dirname(__file__) + '/../../lib')
sys.path.insert(0, QUACLIB)
import u

# Short alias for the logger object exposed by the project's `u` module; it is
# used for every log call in this script.
l = u.l
# Initialise logging through the project helper.  NOTE(review): 'dfpdl' is
# presumably the tag under which this script's log output appears -- confirm
# in u.logging_init.
u.logging_init('dfpdl')

# Keyword arguments handed to every to_csv() call in this script, so that all
# output files are tab-separated.
CSVPARMS = dict(sep='\t')

# NOTE(review): presumably quantile levels (very low / low / median / high /
# very high); nothing in the visible part of this script reads them.
(Q_VLOW, Q_LOW, Q_MID, Q_HIGH, Q_VHIGH) = (0.01, 0.05, 0.50, 0.95, 0.99)

# Largest staleness tabulated for horizon 0; it also fixes how many staleness
# columns the summary frames start out with.
STALE_MAX_WIN = 25
# Largest staleness tabulated for every other horizon.
STALE_MAX_FAIL = 0

# The single command-line argument is the path of one outbreak's pickled
# results.  Its file name has the form LOCATION+DISEASE,DISTANCE followed by
# an extension; everything from the first dot of the file name onwards is
# discarded.
resultname = sys.argv[1]
# Take the file name first and only then cut at the first dot.  Cutting the
# whole path at its first dot breaks as soon as the directory part contains
# one (e.g. './out/x+y,2.pkl.gz' or '/data/v1.2/x+y,2.pkl.gz').
basename = os.path.basename(resultname).split('.')[0]
# Directory holding the result file; the shared input pickle is read from
# here as well.
basedir = os.path.abspath(os.path.dirname(resultname))
# Output directory the script later chdir()s into; named after the outbreak.
dirname = basename
(location, dd) = basename.split('+')
(disease, distance) = dd.split(',')
distance = int(distance)
# Key used to look up this outbreak's truth series.
outbreak = '%s+%s' % (location, disease)

l.info('starting in %s' % basedir)
l.info('outbreak: %s+%s,%d' % (location, disease, distance))

class Namespace(object):
   # Deliberately empty.  NOTE(review): presumably the pickles loaded below
   # hold instances of a class with this name, which must be resolvable from
   # this module while unpickling -- confirm before removing.
   pass

# The shared input for this distance lives next to the result file; pick this
# outbreak's truth series out of it.
input_path = '%s/input,%d.pkl.gz' % (basedir, distance)
input_ = u.pickle_load(input_path)
truth = input_.truth[outbreak]
l.debug('loaded input and truth')

# Predictions for this outbreak: the file named on the command line.
data_raw = u.pickle_load(resultname)
l.debug('loaded outbreak %s' % basename)

# Every output path below is relative to the outbreak's directory.
# NOTE(review): dirname is resolved against the current working directory,
# not basedir, and must already exist.
os.chdir(dirname)

# Check that the truth series is the one for this outbreak, then save it.
assert truth.name == outbreak
truth.to_csv('truth.tsv', **CSVPARMS)

# Build the prediction summaries for every (horizon, training) pair, writing
# each one to a TSV file under h<horizon>/ and keeping it in `data`.
Result_Types = collections.namedtuple('Result_Types', [
   'all',      # all predictions, DataFrame: index periods, columns nows
   'fresh',    # freshest predictions, DataFrame:
               #   index: nows, columns: date, predictions by staleness
   'fresh_d',  # as 'fresh', but index: dates (and no date column)
])
data = {}  # horizon -> training -> Result_Types
horizons = set()
trainings = set()
n_stale = STALE_MAX_WIN + 1  # number of staleness columns allocated up front
for (h, hdata) in data_raw.items():
   horizons.add(h)
   data[h] = {}
   hdir = 'h%d' % h
   u.mkdir_f(hdir)
   # Horizon 0 is tabulated for staleness 0..STALE_MAX_WIN, every other
   # horizon for staleness 0..STALE_MAX_FAIL.
   max_stale = STALE_MAX_WIN if (h == 0) else STALE_MAX_FAIL
   for (t, all_) in hdata.items():
      trainings.add(t)
      # Ensure a continuous sequence of columns.
      first_col = all_.columns.min()
      last_col = all_.columns.max()
      all_ = all_.reindex(columns=range(first_col, last_col + 1))
      all_.to_csv('%s/t%d.pred-all.tsv' % (hdir, t), **CSVPARMS)
      fresh = pd.DataFrame(index=sorted(all_.columns),
                           columns=['date'] + list(range(n_stale)))
      # Start timestamp of each period in the prediction index.
      fresh.loc[:,'date'] = all_.index.to_timestamp(how='start')
      # Set dtype explicitly to avoid AttributeError in corr() later.
      # http://stackoverflow.com/a/18316030/396038
      fresh_d = pd.DataFrame(index=all_.index, columns=range(n_stale),
                             dtype=np.float64)
      n_rows = len(all_)
      n_cols = len(all_.columns)
      for stale in range(max_stale + 1):
         # Walk the diagonal lying `stale` rows below the main one: the
         # prediction in column col_i is read from row col_i + stale.  The
         # range bound stops before the row position runs off the frame.
         for col_i in range(min(n_cols, n_rows - stale)):
            row_i = col_i + stale
            pred = all_.iloc[row_i, col_i]
            fresh.iloc[row_i, stale + 1] = pred  # +1 skips the date column
            fresh_d.iloc[row_i, stale] = pred
      # Staleness columns that never received a value are not kept.
      fresh.dropna(axis=1, how='all', inplace=True)
      fresh_d.dropna(axis=1, how='all', inplace=True)
      fresh.to_csv('%s/t%d.pred-first.tsv' % (hdir, t), **CSVPARMS)
      data[h][t] = Result_Types(all=all_, fresh=fresh, fresh_d=fresh_d)
      l.info('horizon %d, training %d done' % (h, t))

# From here on these are sorted lists, used as table axes.
horizons = sorted(horizons)
trainings = sorted(trainings)

# Compute r^2 against truth for each horizon (rows) and training (columns),
# using only the staleness-0 predictions.  The table is written twice: once as
# is and once transposed.
rsq = pd.DataFrame(index=horizons, columns=trainings)
for (h, hdata) in data.items():
   for (t, r) in hdata.items():
      # The summary step drops staleness columns that are entirely NaN, so
      # column 0 can be missing when no staleness-0 prediction exists.  Leave
      # that cell as NaN instead of aborting with a KeyError; the per-staleness
      # tables treat missing columns the same way.
      if 0 not in r.fresh_d.columns:
         continue
      rsq.loc[h,t] = r.fresh_d.loc[:,0].corr(truth) ** 2
rsq.to_csv('rsquared-h.tsv', **CSVPARMS)
rsq.T.to_csv('rsquared-t.tsv', **CSVPARMS)

# One table per horizon: r^2 against truth for every staleness (rows) and
# training (columns), written to h<horizon>/rsquared.tsv.
for (h, hdata) in data.items():
   hdir = 'h%d' % h
   rsq = pd.DataFrame(index=range(STALE_MAX_WIN+1), columns=trainings)
   for (t, result) in hdata.items():
      by_staleness = result.fresh_d
      # Only staleness columns that survived the earlier dropna() are present.
      for stale in by_staleness.columns:
         corr = by_staleness.loc[:,stale].corr(truth)
         rsq.loc[stale,t] = corr ** 2
   # Staleness rows with no value for any training are left out of the file.
   rsq = rsq.dropna(axis=0, how='all')
   rsq.to_csv('%s/rsquared.tsv' % hdir, **CSVPARMS)

l.info('done')
